{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fine tuning a model\n",
    "\n",
    "Fine tuning is a form of transfer learning: when you only have a small labeled dataset for a specific task, you can pick up a model trained for a different task and fine-tune it for your specific dataset. These pre-trained models are usually trained on much larger datasets, which helps improving performance. \n",
    "\n",
    "In this tutorial we'll look into how to pick up a pre-trained model and fine tune it for a different task. In part (1) we'll train a model and save it to a checkpoint file. In part (2), we'll load the checkpoint file and run the fine-tuning. Feel free to skip part (1) if you already have a checkpoint file to begin with."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before we get started, let us enable INFO level logging so that we can monitor the progress of our runs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "\n",
    "logger = logging.getLogger()\n",
    "logger.setLevel(logging.INFO)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Training a model\n",
    "Let us begin by pre-training a model using a head with 1000 classes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We want to train for 4 epochs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "num_epochs = 4"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will be using synthetic train and test datasets for this example. The transforms used are from torchvision and are applied to the input value in the sample (rather than the target)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from classy_vision.dataset import SyntheticImageDataset\n",
    "\n",
    "train_dataset = SyntheticImageDataset.from_config({\n",
    "    \"batchsize_per_replica\": 32,\n",
    "    \"num_samples\": 200,\n",
    "    \"crop_size\": 224,\n",
    "    \"class_ratio\": 0.5,\n",
    "    \"seed\": 0,\n",
    "    \"use_shuffle\": True,\n",
    "    \"transforms\": [{\n",
    "        \"name\": \"apply_transform_to_key\",\n",
    "        \"transforms\": [\n",
    "            {\"name\": \"ToTensor\"},\n",
    "            {\"name\": \"Normalize\", \"mean\": [0.485, 0.456, 0.406], \"std\": [0.229, 0.224, 0.225]}\n",
    "        ],\n",
    "        \"key\": \"input\"\n",
    "    }]\n",
    "})\n",
    "test_dataset = SyntheticImageDataset.from_config({\n",
    "    \"batchsize_per_replica\": 32,\n",
    "    \"num_samples\": 200,\n",
    "    \"crop_size\": 224,\n",
    "    \"class_ratio\": 0.5,\n",
    "    \"seed\": 0,\n",
    "    \"use_shuffle\": False,\n",
    "    \"transforms\": [{\n",
    "        \"name\": \"apply_transform_to_key\",\n",
    "        \"transforms\": [\n",
    "            {\"name\": \"ToTensor\"},\n",
    "            {\"name\": \"Normalize\", \"mean\": [0.485, 0.456, 0.406], \"std\": [0.229, 0.224, 0.225]}\n",
    "        ],\n",
    "        \"key\": \"input\"\n",
    "    }]\n",
    "})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let us create a ResNet 50 model now."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from classy_vision.models import ResNet\n",
    "\n",
    "model = ResNet.from_config({\n",
    "    \"num_blocks\": [3, 4, 6, 3],\n",
    "    \"small_input\": False,\n",
    "    \"zero_init_bn_residuals\": True\n",
    "})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we will create a head with 1000 classes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from classy_vision.heads import FullyConnectedHead\n",
    "\n",
    "head = FullyConnectedHead(unique_id=\"default_head\", num_classes=1000, in_plane=2048)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let us attach the head to the final block of the model.\n",
    "\n",
    "For ResNet 50, we want to attach to the `3`rd block in the `4`th layer (based on `[3, 4, 6, 3]`). The blocks use 0 indexing, so this maps to `\"block3-2\"`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "model.set_heads({\"block3-2\": [head]})"
   ]
  },
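  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The index arithmetic behind `block3-2` can be sketched as a small helper. This is only an illustration, not part of Classy Vision: `final_block_name` is a hypothetical function, and it assumes the attachment point always follows the `block{layer}-{block}` pattern with 0-based indices, as described above.\n",
    "\n",
    "```python\n",
    "def final_block_name(num_blocks):\n",
    "    # Assumes names of the form 'block{layer}-{block}', both 0-indexed.\n",
    "    last_layer = len(num_blocks) - 1\n",
    "    last_block = num_blocks[-1] - 1\n",
    "    return f'block{last_layer}-{last_block}'\n",
    "\n",
    "\n",
    "print(final_block_name([3, 4, 6, 3]))  # block3-2\n",
    "```\n",
    "\n",
    "For a different `num_blocks` list, the same arithmetic applies to its last entry."
   ]
  },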
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can use a cross entropy loss from Pytorch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from torch.nn.modules.loss import CrossEntropyLoss\n",
    "\n",
    "loss = CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For the optimizer, we will be using SGD."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from classy_vision.optim import build_optimizer\n",
    "\n",
    "\n",
    "optimizer = build_optimizer({\n",
    "    \"name\": \"sgd\",\n",
    "    \"param_schedulers\": {\"lr\": {\"name\": \"step\", \"values\": [0.1, 0.01]}},\n",
    "    \"weight_decay\": 1e-4,\n",
    "    \"momentum\": 0.9,\n",
    "    \"num_epochs\": num_epochs\n",
    "})"
   ]
  },
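  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To build intuition for the `step` schedule with `values` set to `[0.1, 0.01]`, here is a plain-Python sketch. It assumes the scheduler splits training into equal-length segments, one per listed value. `step_lr` and `total_epochs` are hypothetical names used only for this illustration.\n",
    "\n",
    "```python\n",
    "def step_lr(values, progress):\n",
    "    # progress is the fraction of training completed, between 0 and 1.\n",
    "    # Assumption: each value covers an equal share of training.\n",
    "    index = min(int(progress * len(values)), len(values) - 1)\n",
    "    return values[index]\n",
    "\n",
    "\n",
    "total_epochs = 4\n",
    "for epoch in range(total_epochs):\n",
    "    print(epoch, step_lr([0.1, 0.01], epoch / total_epochs))\n",
    "```\n",
    "\n",
    "Under that assumption, the first two epochs use a learning rate of 0.1 and the last two use 0.01."
   ]
  },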
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We want to track the top-1 and top-5 accuracies of the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from classy_vision.meters import AccuracyMeter\n",
    "\n",
    "meters = [AccuracyMeter(topk=[1, 5])]"
   ]
  },
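  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a reminder of what top-k accuracy measures, here is a framework-free sketch. It is not how `AccuracyMeter` is implemented; `topk_accuracy` and the toy scores are made up for illustration. A sample counts as correct when its target class is among the `k` highest-scoring classes.\n",
    "\n",
    "```python\n",
    "def topk_accuracy(scores, targets, k):\n",
    "    # scores: one list of per-class scores for each sample.\n",
    "    # targets: the correct class index for each sample.\n",
    "    correct = 0\n",
    "    for sample_scores, target in zip(scores, targets):\n",
    "        # Class indices sorted from highest to lowest score.\n",
    "        ranked = sorted(range(len(sample_scores)),\n",
    "                        key=lambda c: sample_scores[c], reverse=True)\n",
    "        if target in ranked[:k]:\n",
    "            correct += 1\n",
    "    return correct / len(targets)\n",
    "\n",
    "\n",
    "scores = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.3, 0.5]]\n",
    "targets = [1, 1, 0]\n",
    "print(topk_accuracy(scores, targets, k=1))\n",
    "print(topk_accuracy(scores, targets, k=2))\n",
    "```\n",
    "\n",
    "With these toy values, one of the three samples is correct at `k=1` and two are correct at `k=2`."
   ]
  },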
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's create a directory to save the checkpoints."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "\n",
    "pretrain_checkpoint_dir = f\"/tmp/checkpoint_{time.time()}\"\n",
    "os.mkdir(pretrain_checkpoint_dir)"
   ]
  },
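  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you would rather not build the directory name by hand, one alternative is the standard library's `tempfile.mkdtemp`, which creates a uniquely named directory and returns its path. This is a sketch of that option; `checkpoint_dir` is an illustrative name, and the rest of the tutorial keeps using `pretrain_checkpoint_dir`.\n",
    "\n",
    "```python\n",
    "import os\n",
    "import tempfile\n",
    "\n",
    "# mkdtemp creates the directory and returns its absolute path.\n",
    "checkpoint_dir = tempfile.mkdtemp(prefix='checkpoint_')\n",
    "print(os.path.isdir(checkpoint_dir))  # True\n",
    "```"
   ]
  },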
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Add `LossLrMeterLoggingHook` to monitor the loss and `CheckpointHook` to save the checkpoints."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from classy_vision.hooks import CheckpointHook, LossLrMeterLoggingHook\n",
    "\n",
    "hooks = [\n",
    "    LossLrMeterLoggingHook(),\n",
    "    CheckpointHook(pretrain_checkpoint_dir, input_args={})\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We have all the components ready to setup our pre-training task which trains for 4 epochs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from classy_vision.tasks import ClassificationTask\n",
    "\n",
    "pretrain_task = (\n",
    "    ClassificationTask()\n",
    "    .set_num_epochs(num_epochs)\n",
    "    .set_loss(loss)\n",
    "    .set_model(model)\n",
    "    .set_optimizer(optimizer)\n",
    "    .set_meters(meters)\n",
    "    .set_hooks(hooks)\n",
    "    .set_dataset(train_dataset, \"train\")\n",
    "    .set_dataset(test_dataset, \"test\")\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let us train using a local trainer instance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from classy_vision.trainer import LocalTrainer\n",
    "\n",
    "trainer = LocalTrainer()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we can start training!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "I0502 162930.244 local_trainer.py:24] Using CPU\n",
      "I0502 162930.254 loss_lr_meter_logging_hook.py:38] Starting training. Task: <classy_vision.tasks.classification_task.ClassificationTask object at 0x7f993fa69a10>\n",
      "I0502 165422.367 loss_lr_meter_logging_hook.py:76] Synced meters: [0] train phase 0 (100.00% done), loss: 1.9421, meters: [accuracy_meter(top_1=0.670000,top_5=0.840000)]\n",
      "I0502 165422.368 checkpoint_hook.py:78] Saving checkpoint to '/tmp/checkpoint_1588462166.8579776'...\n",
      "I0502 165611.635 loss_lr_meter_logging_hook.py:76] Synced meters: [0] test phase 0 (100.00% done), loss: 14.6478, meters: [accuracy_meter(top_1=0.535000,top_5=1.000000)]\n",
      "I0502 172233.979 loss_lr_meter_logging_hook.py:76] Synced meters: [0] train phase 1 (100.00% done), loss: 0.9628, meters: [accuracy_meter(top_1=0.910000,top_5=1.000000)]\n",
      "I0502 172233.980 checkpoint_hook.py:78] Saving checkpoint to '/tmp/checkpoint_1588462166.8579776'...\n",
      "I0502 172425.138 loss_lr_meter_logging_hook.py:76] Synced meters: [0] test phase 1 (100.00% done), loss: 0.0000, meters: [accuracy_meter(top_1=1.000000,top_5=1.000000)]\n",
      "I0502 174927.788 loss_lr_meter_logging_hook.py:76] Synced meters: [0] train phase 2 (100.00% done), loss: 0.0000, meters: [accuracy_meter(top_1=1.000000,top_5=1.000000)]\n",
      "I0502 174927.789 checkpoint_hook.py:78] Saving checkpoint to '/tmp/checkpoint_1588462166.8579776'...\n",
      "I0502 175116.757 loss_lr_meter_logging_hook.py:76] Synced meters: [0] test phase 2 (100.00% done), loss: 0.0000, meters: [accuracy_meter(top_1=1.000000,top_5=1.000000)]\n",
      "I0502 181608.966 loss_lr_meter_logging_hook.py:76] Synced meters: [0] train phase 3 (100.00% done), loss: 0.0000, meters: [accuracy_meter(top_1=1.000000,top_5=1.000000)]\n",
      "I0502 181608.967 checkpoint_hook.py:78] Saving checkpoint to '/tmp/checkpoint_1588462166.8579776'...\n",
      "I0502 181757.168 loss_lr_meter_logging_hook.py:76] Synced meters: [0] test phase 3 (100.00% done), loss: 0.0000, meters: [accuracy_meter(top_1=1.000000,top_5=1.000000)]\n"
     ]
    }
   ],
   "source": [
    "trainer.train(pretrain_task)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Fine-tuning the model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The original model was trained for 1000 classes. Let's fine-tune it for a problem with only 2 classes. To keep things fast we'll run a single epoch:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "num_epochs = 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can re-use the same synthetic datasets as before."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let us again create a ResNet 50 model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from classy_vision.models import ResNet\n",
    "\n",
    "model = ResNet.from_config({\n",
    "    \"num_blocks\": [3, 4, 6, 3],\n",
    "    \"small_input\": False,\n",
    "    \"zero_init_bn_residuals\": True\n",
    "})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For fine tuning, we will create a head with just 2 classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from classy_vision.heads import FullyConnectedHead\n",
    "\n",
    "head = FullyConnectedHead(unique_id=\"default_head\", num_classes=2, in_plane=2048)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let us attach the head to the final block of the model, like before."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "model.set_heads({\"block3-2\": [head]})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For the optimizer, we will be using RMSProp this time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from classy_vision.optim import build_optimizer\n",
    "\n",
    "\n",
    "optimizer = build_optimizer({\n",
    "    \"name\": \"rmsprop\",\n",
    "    \"param_schedulers\": {\"lr\": {\"name\": \"step\", \"values\": [0.1, 0.01]}},\n",
    "    \"weight_decay\": 1e-4,\n",
    "    \"momentum\": 0.9,\n",
    "    \"alpha\": 0.9,\n",
    "    \"eps\": 1e-3,\n",
    "    \"num_epochs\": num_epochs\n",
    "})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We want to track the top-1 accuracy of the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from classy_vision.meters import AccuracyMeter\n",
    "\n",
    "meters = [AccuracyMeter(topk=[1])]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will create a new directory to save the checkpoints for our fine tuning run."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "\n",
    "fine_tuning_checkpoint_dir = f\"/tmp/checkpoint_{time.time()}\"\n",
    "os.mkdir(fine_tuning_checkpoint_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Hooks are also the same as before."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from classy_vision.hooks import CheckpointHook, LossLrMeterLoggingHook\n",
    "\n",
    "hooks = [\n",
    "    LossLrMeterLoggingHook(),\n",
    "    CheckpointHook(fine_tuning_checkpoint_dir, input_args={})\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can setup our fine tuning task."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from classy_vision.tasks import FineTuningTask\n",
    "\n",
    "fine_tuning_task = (\n",
    "    FineTuningTask()\n",
    "    .set_num_epochs(num_epochs)\n",
    "    .set_loss(loss)\n",
    "    .set_model(model)\n",
    "    .set_optimizer(optimizer)\n",
    "    .set_meters(meters)\n",
    "    .set_hooks(hooks)\n",
    "    .set_dataset(train_dataset, \"train\")\n",
    "    .set_dataset(test_dataset, \"test\")\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since this is a fine tuning task, there are some other configurations which need to be done."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We don't want to re-train the trunk, so we will be freezing it. This is optional."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<classy_vision.tasks.fine_tuning_task.FineTuningTask object at 0x7f993450e690>"
      ]
     },
     "execution_count": 23,
     "metadata": {
      "bento_obj_id": "140295984440976"
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fine_tuning_task.set_freeze_trunk(True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We want to start training the heads from scratch, so we will be resetting them. This is required in this example since the pre-trained heads are not compatible with the heads in fine tuning (they have different number of classes). Otherwise, this is also optional."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<classy_vision.tasks.fine_tuning_task.FineTuningTask object at 0x7f993450e690>"
      ]
     },
     "execution_count": 24,
     "metadata": {
      "bento_obj_id": "140295984440976"
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fine_tuning_task.set_reset_heads(True)"
   ]
  },
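  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Conceptually, resetting the heads means that only the trunk weights are taken from the checkpoint, while the new head keeps its fresh initialization. The sketch below illustrates this with plain dictionaries standing in for state dicts. It is not Classy Vision's implementation: `load_trunk_only` and the key names (`trunk.`, `heads.`) are hypothetical.\n",
    "\n",
    "```python\n",
    "def load_trunk_only(new_state, checkpoint_state, head_prefix):\n",
    "    # Copy every checkpoint entry except those under head_prefix,\n",
    "    # so the new head keeps its own freshly initialized weights.\n",
    "    merged = dict(new_state)\n",
    "    for name, value in checkpoint_state.items():\n",
    "        if not name.startswith(head_prefix):\n",
    "            merged[name] = value\n",
    "    return merged\n",
    "\n",
    "\n",
    "# Toy stand-ins: real state dicts map names to tensors.\n",
    "checkpoint_state = {\n",
    "    'trunk.conv.weight': [0.5, -0.2],\n",
    "    'heads.fc.weight': [0.0] * 1000,  # 1000-class head\n",
    "}\n",
    "new_state = {\n",
    "    'trunk.conv.weight': [0.0, 0.0],\n",
    "    'heads.fc.weight': [0.0] * 2,  # 2-class head\n",
    "}\n",
    "\n",
    "merged = load_trunk_only(new_state, checkpoint_state, head_prefix='heads.')\n",
    "print(merged['trunk.conv.weight'])     # [0.5, -0.2]\n",
    "print(len(merged['heads.fc.weight']))  # 2\n",
    "```\n",
    "\n",
    "The 1000-class head weights are skipped because their shape would not fit the 2-class head."
   ]
  },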
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We need to give our task the path to our pre-trained checkpoint, which it'll need to start fine-tuning on."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "I0502 181800.442 util.py:483] Attempting to load checkpoint from /tmp/checkpoint_1588462166.8579776/checkpoint.torch\n",
      "I0502 181800.639 util.py:488] Loaded checkpoint from /tmp/checkpoint_1588462166.8579776/checkpoint.torch\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<classy_vision.tasks.fine_tuning_task.FineTuningTask object at 0x7f993450e690>"
      ]
     },
     "execution_count": 25,
     "metadata": {
      "bento_obj_id": "140295984440976"
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "fine_tuning_task.set_pretrained_checkpoint(pretrain_checkpoint_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let us fine tune!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "I0502 181800.646 local_trainer.py:24] Using CPU\n",
      "I0502 181801.040 util.py:510] Model state load successful\n",
      "I0502 181801.042 loss_lr_meter_logging_hook.py:38] Starting training. Task: <classy_vision.tasks.fine_tuning_task.FineTuningTask object at 0x7f993450e690>\n",
      "I0502 181950.398 loss_lr_meter_logging_hook.py:76] Synced meters: [0] train phase 0 (100.00% done), loss: 0.1105, meters: [accuracy_meter(top_1=0.895000)]\n",
      "I0502 181950.399 checkpoint_hook.py:78] Saving checkpoint to '/tmp/checkpoint_1588468680.417235'...\n",
      "I0502 182137.554 loss_lr_meter_logging_hook.py:76] Synced meters: [0] test phase 0 (100.00% done), loss: 0.0000, meters: [accuracy_meter(top_1=1.000000)]\n"
     ]
    }
   ],
   "source": [
    "trainer.train(fine_tuning_task)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3. Conclusion\n",
    "\n",
    "In this tutorial, we learned how to load a pre-trained model in Classy Vision and how to fine-tune it for a different task. We did that by using the `FineTuningTask` abstraction, which lets you load the pretrained model weights, attaching a new head to the model and optionally freeze the weights of the original model. \n",
    "\n",
    "To learn more about about fine-tuning, check out our documentation for [FineTuningTask](https://classyvision.ai/api/tasks.html#classy_vision.tasks.FineTuningTask) and [ClassyHead](https://classyvision.ai/api/heads.html)"
   ]
  }
 ],
 "metadata": {
  "bento_stylesheets": {
   "bento/extensions/flow/main.css": true,
   "bento/extensions/kernel_selector/main.css": true,
   "bento/extensions/kernel_ui/main.css": true,
   "bento/extensions/new_kernel/main.css": true,
   "bento/extensions/system_usage/main.css": true,
   "bento/extensions/theme/main.css": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
